[ML] Adding asynchronous start up logic for the inference API internals #135462
| /** | ||
| * TODO implement this functionality to ensure that we don't block node bootups |
Converting bedrock is going to take a little more work. Probably best to do this in a separate PR because this one is already 50 files 😬
| var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext()); | ||
| var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), requestMetadata); | ||
| SubscribableListener.newForked(sender::startAsynchronously).<InferenceServiceResults>andThen((authListener) -> { |
Now we're doing an async start and then once that completes we do the rest of the functionality as normal.
…asticsearch into ml-async-sender-init
Pinging @elastic/ml-core (Team:ML)
| @Override | ||
| public void startSynchronously() { | ||
| if (started.compareAndSet(false, true)) { | ||
| startInternal(ActionListener.noop()); |
Will this cause any exception thrown in startInternal() to be ignored when doing a synchronous start? Also, do we need to make sure that we always call waitForStartToComplete() before returning from this method? If someone calls startAsynchronously() then another thread immediately calls startSynchronously(), the second call will return immediately (because we already set started to true) but the sender won't actually have started yet.
Good point, I'll make those changes.
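For reference, a minimal sketch of the race described above and one way to close it. This is not the PR's code: StartOnceSender, runStart, isStartComplete, and the timeout parameter are hypothetical names, and a plain CountDownLatch stands in for whatever the real sender uses. The point is only that startSynchronously() waits on the latch even when it loses the compareAndSet race.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

// Hypothetical stand-in for the sender; not the actual HttpRequestSender.
public class StartOnceSender {
    private final AtomicBoolean started = new AtomicBoolean(false);
    private final CountDownLatch startCompleted = new CountDownLatch(1);
    private final Runnable startWork;
    private final long timeoutMillis;

    public StartOnceSender(Runnable startWork, long timeoutMillis) {
        this.startWork = startWork;
        this.timeoutMillis = timeoutMillis;
    }

    // Only the first caller runs the start work, on a separate thread.
    public void startAsynchronously() {
        if (started.compareAndSet(false, true)) {
            Thread thread = new Thread(this::runStart, "sender-start");
            thread.setDaemon(true);
            thread.start();
        }
    }

    public void startSynchronously() {
        if (started.compareAndSet(false, true)) {
            // An exception thrown by the start work propagates to this caller.
            runStart();
        }
        // Always wait: if another thread won the race, its start may still be running.
        waitForStartToComplete();
    }

    private void runStart() {
        try {
            startWork.run();
        } finally {
            startCompleted.countDown();
        }
    }

    private void waitForStartToComplete() {
        try {
            if (startCompleted.await(timeoutMillis, TimeUnit.MILLISECONDS) == false) {
                throw new IllegalStateException("sender did not start in time");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting for the sender to start", e);
        }
    }

    public boolean isStartComplete() {
        return startCompleted.getCount() == 0;
    }
}
```

With this shape, calling startAsynchronously() and then immediately calling startSynchronously() from another thread blocks the second caller until the first start has finished or the timeout expires.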
| init(); | ||
| doStart(model, listener); | ||
| SubscribableListener.newForked(this::init) | ||
| .<Boolean>andThen((doStartListener) -> doStart(model, doStartListener)) |
What is the purpose of calling doStart() here? It seems to be a no-op that just immediately returns.
The idea is that it can be overridden by child classes. In reality I don't think any actually override it yet. The Elasticsearch integration does use it but that doesn't extend from SenderService.
Gotcha, thanks for the explanation
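To illustrate the hook pattern discussed in this thread, here is a rough sketch that chains an asynchronous init step into an overridable doStart step. It uses CompletableFuture from the JDK as a stand-in for SubscribableListener, and BaseService, DefaultService, WarmingService, and the steps list are all hypothetical names, not classes from the PR.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;

// Sketch only: CompletableFuture stands in for SubscribableListener.
public class StartHookSketch {

    public abstract static class BaseService {
        public final List<String> steps = new ArrayList<>();

        // Asynchronous initialization; completes immediately in this sketch.
        protected CompletableFuture<Void> init() {
            steps.add("init");
            return CompletableFuture.completedFuture(null);
        }

        // Default hook: a no-op that reports success. Subclasses may override it.
        protected CompletableFuture<Boolean> doStart(String model) {
            return CompletableFuture.completedFuture(true);
        }

        // init runs first; doStart runs only after init has completed.
        public final CompletableFuture<Boolean> start(String model) {
            return init().thenCompose(ignored -> doStart(model));
        }
    }

    // Relies on the default no-op hook.
    public static class DefaultService extends BaseService {}

    // Overrides the hook to do extra work after init.
    public static class WarmingService extends BaseService {
        @Override
        protected CompletableFuture<Boolean> doStart(String model) {
            steps.add("doStart:" + model);
            return CompletableFuture.completedFuture(true);
        }
    }
}
```

The default implementation stays a no-op, so a service that does not override the hook behaves as if only init ran.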
| sender.startSynchronously(); | ||
| sender.startSynchronously(); | ||
| sender.startSynchronously(); |
It would be good to add some tests for the startAsynchronously() method, since it's a distinct implementation from startSynchronously(). Also, a test that calling startAsynchronously() followed immediately by startSynchronously() behaves the way we expect would be good.
| @Override | ||
| public void startAsynchronously(ActionListener<Void> listener) { | ||
| listener.onResponse(null); |
Would it be better to have this method throw the same UnsupportedOperationException as AmazonBedrockRequestSender.startAsynchronously()? It would be nice to have a little extra confidence that we're not calling an unsupported method.
| @Override | ||
| public void start() { | ||
| public void startAsynchronously(ActionListener<Void> listener) { | ||
| throw new UnsupportedOperationException("not implemented"); |
Would it be worth wrapping this throw in a check on the value of started? If the sender has already been started, then calling startAsynchronously() should have no effect.
Hmm I think in that situation we should still throw. It would be a bug if we're ever calling that method for AmazonBedrockRequestSender.
| @SuppressWarnings("unchecked") | ||
| public void testGetAuthorization_OnResponseCalledOnce() throws IOException { |
Why was this test deleted?
Yeah sorry, I meant to comment on this and forgot 😅. I'll add it back. It was giving me problems because we're mocking the listener, but I think I found a way to fix it.
| // Checking for both exception types because there's a race condition between the Error being thrown on a separate thread | ||
| // and the startCompleted latch timing out waiting for the start to complete |
HttpRequestSender.startInternal() only catches and handles Exception, so any Error thrown in that method will always escape and the listener will never be invoked. That means neither the maybeDieOnAnotherThread() call nor the waitForStartToComplete() call in startSynchronously() ever happens, so we wouldn't expect to see the IllegalStateException thrown from waitForStartToComplete().
If I change the test to use an IllegalArgumentException wrapping an Error, then the listener is invoked and we always get the IllegalStateException thrown from startSynchronously() due to timing out waiting for the sender to start. However, with that change, the test fails because the error is thrown in another thread. I don't know how to tell a test to expect an exception to be thrown in another thread, but it looks like CloseFollowerIndexIT.wrapUncaughtExceptionHandler() might be trying to solve the same problem.
I wonder if we need to rethrow the Error at all in the case where we catch an Exception with an Error as one of its causes, or just log it and allow the waitForStartToComplete() call to inevitably time out?
Good catch.
> I wonder if we need to rethrow the Error at all in the case where we catch an Exception with an Error as one of its causes, or just log it and allow the waitForStartToComplete() call to inevitably time out?
Yeah, I think I'm going to just log it and rely on waitForStartToComplete(). After we refactor bedrock, I'm pretty sure we can remove startSynchronously() altogether or just use it for tests.
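A small sketch of the distinction discussed in this thread: a catch (Exception e) block never sees a directly thrown Error, but it does see an Exception that wraps one. ErrorCauses, findError, and runStart are hypothetical names for illustration and are not the helpers used in the PR.

```java
import java.util.Optional;

// Hypothetical helper for illustration; not the PR's code.
public class ErrorCauses {

    // Returns the first Error found in the cause chain of t, if any.
    // The depth bound guards against a cyclic cause chain.
    public static Optional<Error> findError(Throwable t) {
        for (int depth = 0; t != null && depth < 32; depth++) {
            if (t instanceof Error) {
                return Optional.of((Error) t);
            }
            t = t.getCause();
        }
        return Optional.empty();
    }

    // Mimics a start method that only catches Exception.
    // A directly thrown Error is not caught here and propagates to the caller.
    public static String runStart(Runnable startWork) {
        try {
            startWork.run();
            return "started";
        } catch (Exception e) {
            return findError(e).isPresent() ? "failed: exception wrapping an Error" : "failed: exception";
        }
    }
}
```

In the wrapped case the catch block gets a chance to log the Error and carry on, which is the option settled on above. In the direct case nothing in runStart observes the failure at all.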
| public void testCreateSender_CanCallStartAsyncMultipleTimes() throws Exception { |
This test and the one below it could be improved a little by verifying that no matter how many times we call startAsynchronously() or startSynchronously(), we only call HttpClientManager.start() once:
var clientManagerSpy = spy(clientManager);
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManagerSpy, mockClusterServiceEmpty());
...
for (int i = 0; i < asyncCalls; i++) {
    PlainActionFuture<Void> listener = listenerList.get(i);
    assertNull(listener.actionGet(TIMEOUT));
}
verify(clientManagerSpy, times(1)).start();
It would also be nice if we could verify that we're calling waitForStartToComplete() the expected number of times.
| // Handle the case where start*() was already called and this would return immediately because the started flag is already true | ||
| waitForStartToComplete(); |
I'm wondering if we need to do something similar for async calls, since if two async calls come in one after the other, the second one will complete immediately even if the first one hasn't finished starting the sender yet.
Good idea. I tried to come up with a solution that avoids having to spin up a thread just to call waitForStartToComplete(), since most of the time it will simply return.
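One possible shape for this, sketched with the JDK's CompletableFuture rather than the PR's actual listener types: every caller subscribes to the same completion, so a second async call is not notified until the first start has finished, and no extra thread is needed to wait. SharedStartFuture and its constructor argument are hypothetical names.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

// Sketch only: all callers share one completion instead of blocking on a latch.
public class SharedStartFuture {
    private final AtomicBoolean started = new AtomicBoolean(false);
    private final CompletableFuture<Void> startFuture = new CompletableFuture<>();
    // The start work receives the future and completes it when startup is done.
    private final Consumer<CompletableFuture<Void>> startWork;

    public SharedStartFuture(Consumer<CompletableFuture<Void>> startWork) {
        this.startWork = startWork;
    }

    public void startAsynchronously(BiConsumer<Void, Throwable> listener) {
        if (started.compareAndSet(false, true)) {
            try {
                startWork.accept(startFuture);
            } catch (Exception e) {
                // Report a failure to kick off the start to every subscriber.
                startFuture.completeExceptionally(e);
            }
        }
        // Runs right away if the start already finished, otherwise once it completes.
        startFuture.whenComplete(listener);
    }
}
```

A caller that arrives after startup has completed is notified on its own thread immediately, which matches the common case where the wait would simply return.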
This PR adds an asynchronous version of the startup logic within the inference API. We haven't yet seen problems with doing this synchronously, but doing it asynchronously makes changes to the dynamic preconfigured inference endpoints a little easier.
This PR also adds SubscribableListener in a few places to make the flow easier to follow, since we're relying on a listener for the applicable methods instead of blocking.